# ------------------------------------------------------------------------------
# Measures performance of limited bandwidth models
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bsub -q big -R "rusage[mem=5000]" bash 03_summarize_performance.sh
# ------------------------------------------------------------------------------

# Seeding ----------------------------------------------------------------------
# Fix the RNG state up front so any random draws made later are reproducible.
set.seed(seed = 1)

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(here)
library(yaml)
library(data.table)
library(tidyverse)
library(Metrics)
library(Matrix)
library(glue)
library(optparse)
library(MLmetrics)
library(dplyr)

# Scratch directory for this stage: the *_list.rds inputs are read from here
# and the summarized performance .rds files are written back here.
temp <- here::here(
  file.path("code", "06_physician_boundedness", "02_behavioral_gbms", "temp")
)
# Load Data --------------------------------------------------------------------
message("Loading data...")

# Left empty on purpose; presumably consumed by an `{overnight_lab}`
# placeholder in the glue() template for the test-cohort path below
# (confirm in lib/filepaths.yml).
overnight_lab <- ""
paths <- yaml::read_yaml(here("lib", "filepaths.yml"))
ids <- readRDS(file = glue(paths$analysis$test_cohort))

# Per-model performance results, presumably written by an earlier step of
# this pipeline into `temp`.
performance_list <- readRDS(file = file.path(temp, "performance_list.rds"))
performance_dsmall_list <- readRDS(
  file = file.path(temp, "performance_dsmall_list.rds")
)

# Exclusions -------------------------------------------------------------------
message("Exclusions...")
# Keep only rows not flagged for exclusion. which() also drops rows whose
# `exclude` flag is NA.
# NOTE(review): `ids` is not referenced again in this script; confirm whether
# this load/filter is still needed.
keep_obs <- which(!ids$exclude)
ids <- ids[keep_obs, ]
setDT(ids)

# Performance ------------------------------------------------------------------
message("Getting performance...")

# Collapse a long table of per-model performance values into one row per
# group, with the mean performance and a normal-approximation 95% interval
# (mean +/- 1.96 * standard error of the mean).
#
# perf_long: data frame with a numeric `performance` column plus the grouping
#   columns passed in `...`.
# ...: grouping columns (bare names). Their order sets the leading column
#   order of the result.
# Returns an ungrouped tibble with the grouping columns, mean_performance,
#   std.error, delta_lo and delta_hi. std.error (and therefore the interval) is
#   NA for a group with a single row, because sd() of one value is NA.
summarize_performance <- function(perf_long, ...) {
  perf_long %>%
    group_by(...) %>%
    summarize(
      mean_performance = mean(performance),
      std.error = sd(performance) / sqrt(n()),
      .groups = "drop"
    ) %>%
    mutate(
      delta_lo = mean_performance - 1.96 * std.error,
      delta_hi = mean_performance + 1.96 * std.error
    )
}

# seq_along() (not 1:length()) so an empty list yields an empty tibble rather
# than a length mismatch against c(1, 0).
performance_df <- tibble(
  model_num = seq_along(performance_list),
  perf = performance_list
) %>%
  unnest(perf) %>%
  summarize_performance(niters, target_name, measure_name)

# tibble() silently drops a NULL column, which would later surface as an
# obscure "object 'model_name' not found" error, so fail early instead.
if (is.null(names(performance_dsmall_list))) {
  stop("performance_dsmall_list must be a named list.", call. = FALSE)
}

performance_dsmall_df <- tibble(
  model_name = names(performance_dsmall_list),
  perf = performance_dsmall_list
) %>%
  unnest(perf) %>%
  # Depth is the trailing integer of the model name. Take the whole run of
  # trailing digits (not just the last character) so depths >= 10 parse.
  mutate(depth = as.numeric(str_extract(model_name, "\\d+$"))) %>%
  summarize_performance(depth, measure_name, target_name)

# Save -------------------------------------------------------------------------
message("Saving performance df...")
# Both summaries are written next to their inputs in `temp`.
save_fp <- file.path(temp, "performance.rds")
save_fp_dsmall <- file.path(temp, "performance_dsmall.rds")
write_rds(performance_df, save_fp)
write_rds(performance_dsmall_df, save_fp_dsmall)

message("Done.")
